{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Training a causal language model from scratch (PyTorch)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Install the Transformers, Datasets, and Evaluate libraries to run this notebook."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "!pip install datasets evaluate transformers[sentencepiece]\n",
    "!pip install accelerate\n",
    "# To run the training on TPU, you will need to uncomment the following line:\n",
    "# !pip install cloud-tpu-client==0.10 torch==1.9.0 https://storage.googleapis.com/tpu-pytorch/wheels/torch_xla-1.9-cp37-cp37m-linux_x86_64.whl\n",
    "!apt install git-lfs"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "You will need to set up git; adapt your email and name in the following cell."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "!git config --global user.email \"you@example.com\"\n",
    "!git config --global user.name \"Your Name\""
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "You will also need to be logged in to the Hugging Face Hub. Execute the following and enter your credentials."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from huggingface_hub import notebook_login\n",
    "\n",
    "notebook_login()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def any_keyword_in_string(string, keywords):\n",
    "    \"\"\"Return True when at least one of `keywords` occurs as a substring of `string`.\"\"\"\n",
    "    return any(keyword in string for keyword in keywords)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "False True"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "filters = [\"pandas\", \"sklearn\", \"matplotlib\", \"seaborn\"]\n",
    "example_1 = \"import numpy as np\"\n",
    "example_2 = \"import pandas as pd\"\n",
    "\n",
    "print(\n",
    "    any_keyword_in_string(example_1, filters), any_keyword_in_string(example_2, filters)\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from collections import defaultdict\n",
    "from tqdm import tqdm\n",
    "from datasets import Dataset\n",
    "\n",
    "\n",
    "def filter_streaming_dataset(dataset, filters):\n",
    "    \"\"\"Collect the samples of `dataset` whose \"content\" contains any of `filters`.\n",
    "\n",
    "    Iterates over `dataset` once, keeps every field of each matching sample,\n",
    "    prints the fraction of samples that were kept and returns the kept samples\n",
    "    as a `Dataset`.\n",
    "    \"\"\"\n",
    "    filtered_dict = defaultdict(list)\n",
    "    total = 0\n",
    "    for sample in tqdm(iter(dataset)):\n",
    "        total += 1\n",
    "        if any_keyword_in_string(sample[\"content\"], filters):\n",
    "            for k, v in sample.items():\n",
    "                filtered_dict[k].append(v)\n",
    "    # An empty stream would otherwise raise ZeroDivisionError in the summary below.\n",
    "    kept_ratio = len(filtered_dict[\"content\"]) / total if total else 0.0\n",
    "    print(f\"{kept_ratio:.2%} of data after filtering.\")\n",
    "    return Dataset.from_dict(filtered_dict)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "3.26% of data after filtering."
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# This cell will take a very long time to execute, so you should skip it and go to\n",
    "# the next one!\n",
    "from datasets import load_dataset\n",
    "\n",
    "split = \"train\"  # \"valid\"\n",
    "filters = [\"pandas\", \"sklearn\", \"matplotlib\", \"seaborn\"]\n",
    "\n",
    "data = load_dataset(f\"transformersbook/codeparrot-{split}\", split=split, streaming=True)\n",
    "filtered_data = filter_streaming_dataset(data, filters)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "DatasetDict({\n",
       "    train: Dataset({\n",
       "        features: ['repo_name', 'path', 'copies', 'size', 'content', 'license'],\n",
       "        num_rows: 606720\n",
       "    })\n",
       "    valid: Dataset({\n",
       "        features: ['repo_name', 'path', 'copies', 'size', 'content', 'license'],\n",
       "        num_rows: 3322\n",
       "    })\n",
       "})"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from datasets import load_dataset, DatasetDict\n",
    "\n",
    "ds_train = load_dataset(\"huggingface-course/codeparrot-ds-train\", split=\"train\")\n",
    "ds_valid = load_dataset(\"huggingface-course/codeparrot-ds-valid\", split=\"validation\")\n",
    "\n",
    "raw_datasets = DatasetDict(\n",
    "    {\n",
    "        \"train\": ds_train,  # .shuffle().select(range(50000)),\n",
    "        \"valid\": ds_valid,  # .shuffle().select(range(500))\n",
    "    }\n",
    ")\n",
    "\n",
    "raw_datasets"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'REPO_NAME: kmike/scikit-learn'\n",
       "'PATH: sklearn/utils/__init__.py'\n",
       "'COPIES: 3'\n",
       "'SIZE: 10094'\n",
       "'''CONTENT: \"\"\"\n",
       "The :mod:`sklearn.utils` module includes various utilites.\n",
       "\"\"\"\n",
       "\n",
       "from collections import Sequence\n",
       "\n",
       "import numpy as np\n",
       "from scipy.sparse import issparse\n",
       "import warnings\n",
       "\n",
       "from .murmurhash import murm\n",
       "LICENSE: bsd-3-clause'''"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "for key in raw_datasets[\"train\"][0]:\n",
    "    print(f\"{key.upper()}: {raw_datasets['train'][0][key][:200]}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Input IDs length: 34\n",
       "Input chunk lengths: [128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 117, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 41]\n",
       "Chunk mapping: [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from transformers import AutoTokenizer\n",
    "\n",
    "context_length = 128\n",
    "tokenizer = AutoTokenizer.from_pretrained(\"huggingface-course/code-search-net-tokenizer\")\n",
    "\n",
    "outputs = tokenizer(\n",
    "    raw_datasets[\"train\"][:2][\"content\"],\n",
    "    truncation=True,\n",
    "    max_length=context_length,\n",
    "    return_overflowing_tokens=True,\n",
    "    return_length=True,\n",
    ")\n",
    "\n",
    "print(f\"Input IDs length: {len(outputs['input_ids'])}\")\n",
    "print(f\"Input chunk lengths: {(outputs['length'])}\")\n",
    "print(f\"Chunk mapping: {outputs['overflow_to_sample_mapping']}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "DatasetDict({\n",
       "    train: Dataset({\n",
       "        features: ['input_ids'],\n",
       "        num_rows: 16702061\n",
       "    })\n",
       "    valid: Dataset({\n",
       "        features: ['input_ids'],\n",
       "        num_rows: 93164\n",
       "    })\n",
       "})"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "def tokenize(element):\n",
    "    \"\"\"Tokenize `element[\"content\"]` into chunks and keep only the full-length ones.\n",
    "\n",
    "    Every chunk whose length differs from `context_length` is dropped, so the\n",
    "    returned \"input_ids\" only holds sequences of exactly `context_length` tokens.\n",
    "    \"\"\"\n",
    "    encoded = tokenizer(\n",
    "        element[\"content\"],\n",
    "        truncation=True,\n",
    "        max_length=context_length,\n",
    "        return_overflowing_tokens=True,\n",
    "        return_length=True,\n",
    "    )\n",
    "    full_chunks = [\n",
    "        chunk_ids\n",
    "        for chunk_len, chunk_ids in zip(encoded[\"length\"], encoded[\"input_ids\"])\n",
    "        if chunk_len == context_length\n",
    "    ]\n",
    "    return {\"input_ids\": full_chunks}\n",
    "\n",
    "\n",
    "tokenized_datasets = raw_datasets.map(\n",
    "    tokenize, batched=True, remove_columns=raw_datasets[\"train\"].column_names\n",
    ")\n",
    "tokenized_datasets"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from transformers import AutoTokenizer, GPT2LMHeadModel, AutoConfig\n",
    "\n",
    "config = AutoConfig.from_pretrained(\n",
    "    \"gpt2\",\n",
    "    vocab_size=len(tokenizer),\n",
    "    n_ctx=context_length,\n",
    "    bos_token_id=tokenizer.bos_token_id,\n",
    "    eos_token_id=tokenizer.eos_token_id,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "GPT-2 size: 124.2M parameters"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model = GPT2LMHeadModel(config)\n",
    "model_size = sum(t.numel() for t in model.parameters())\n",
    "print(f\"GPT-2 size: {model_size/1000**2:.1f}M parameters\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from transformers import DataCollatorForLanguageModeling\n",
    "\n",
    "tokenizer.pad_token = tokenizer.eos_token\n",
    "data_collator = DataCollatorForLanguageModeling(tokenizer, mlm=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "input_ids shape: torch.Size([5, 128])\n",
       "attention_mask shape: torch.Size([5, 128])\n",
       "labels shape: torch.Size([5, 128])"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "out = data_collator([tokenized_datasets[\"train\"][i] for i in range(5)])\n",
    "for key in out:\n",
    "    print(f\"{key} shape: {out[key].shape}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from huggingface_hub import notebook_login\n",
    "\n",
    "notebook_login()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from transformers import Trainer, TrainingArguments\n",
    "\n",
    "args = TrainingArguments(\n",
    "    output_dir=\"codeparrot-ds\",\n",
    "    per_device_train_batch_size=32,\n",
    "    per_device_eval_batch_size=32,\n",
    "    evaluation_strategy=\"steps\",\n",
    "    eval_steps=5_000,\n",
    "    logging_steps=5_000,\n",
    "    gradient_accumulation_steps=8,\n",
    "    num_train_epochs=1,\n",
    "    weight_decay=0.1,\n",
    "    warmup_steps=1_000,\n",
    "    lr_scheduler_type=\"cosine\",\n",
    "    learning_rate=5e-4,\n",
    "    save_steps=5_000,\n",
    "    fp16=True,\n",
    "    push_to_hub=True,\n",
    ")\n",
    "\n",
    "trainer = Trainer(\n",
    "    model=model,\n",
    "    tokenizer=tokenizer,\n",
    "    args=args,\n",
    "    data_collator=data_collator,\n",
    "    train_dataset=tokenized_datasets[\"train\"],\n",
    "    eval_dataset=tokenized_datasets[\"valid\"],\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "trainer.train()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "trainer.push_to_hub()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "from transformers import pipeline\n",
    "\n",
    "device = torch.device(\"cuda\") if torch.cuda.is_available() else torch.device(\"cpu\")\n",
    "pipe = pipeline(\n",
    "    \"text-generation\", model=\"huggingface-course/codeparrot-ds\", device=device\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "# create some data\n",
       "x = np.random.randn(100)\n",
       "y = np.random.randn(100)\n",
       "\n",
       "# create scatter plot with x, y\n",
       "plt.scatter(x, y)\n",
       "\n",
       "# create scatter"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "txt = \"\"\"\\\n",
    "# create some data\n",
    "x = np.random.randn(100)\n",
    "y = np.random.randn(100)\n",
    "\n",
    "# create scatter plot with x, y\n",
    "\"\"\"\n",
    "print(pipe(txt, num_return_sequences=1)[0][\"generated_text\"])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "# create some data\n",
       "x = np.random.randn(100)\n",
       "y = np.random.randn(100)\n",
       "\n",
       "# create dataframe from x and y\n",
       "df = pd.DataFrame({'x': x, 'y': y})\n",
       "df.insert(0,'x', x)\n",
       "for"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "txt = \"\"\"\\\n",
    "# create some data\n",
    "x = np.random.randn(100)\n",
    "y = np.random.randn(100)\n",
    "\n",
    "# create dataframe from x and y\n",
    "\"\"\"\n",
    "print(pipe(txt, num_return_sequences=1)[0][\"generated_text\"])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "# dataframe with profession, income and name\n",
       "df = pd.DataFrame({'profession': x, 'income':y, 'name': z})\n",
       "\n",
       "# calculate the mean income per profession\n",
       "profession = df.groupby(['profession']).mean()\n",
       "\n",
       "# compute the"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "txt = \"\"\"\\\n",
    "# dataframe with profession, income and name\n",
    "df = pd.DataFrame({'profession': x, 'income':y, 'name': z})\n",
    "\n",
    "# calculate the mean income per profession\n",
    "\"\"\"\n",
    "print(pipe(txt, num_return_sequences=1)[0][\"generated_text\"])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "# import random forest regressor from scikit-learn\n",
       "from sklearn.ensemble import RandomForestRegressor\n",
       "\n",
       "# fit random forest model with 300 estimators on X, y:\n",
       "rf = RandomForestRegressor(n_estimators=300, random_state=random_state, max_depth=3)\n",
       "rf.fit(X, y)\n",
       "rf"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "txt = \"\"\"\n",
    "# import random forest regressor from scikit-learn\n",
    "from sklearn.ensemble import RandomForestRegressor\n",
    "\n",
    "# fit random forest model with 300 estimators on X, y:\n",
    "\"\"\"\n",
    "print(pipe(txt, num_return_sequences=1)[0][\"generated_text\"])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'Keyword has not single token: testtest'"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "keytoken_ids = []\n",
    "for keyword in [\n",
    "    \"plt\",\n",
    "    \"pd\",\n",
    "    \"sk\",\n",
    "    \"fit\",\n",
    "    \"predict\",\n",
    "    \" plt\",\n",
    "    \" pd\",\n",
    "    \" sk\",\n",
    "    \" fit\",\n",
    "    \" predict\",\n",
    "    \"testtest\",\n",
    "]:\n",
    "    ids = tokenizer([keyword]).input_ids[0]\n",
    "    if len(ids) == 1:\n",
    "        keytoken_ids.append(ids[0])\n",
    "    else:\n",
    "        print(f\"Keyword has not single token: {keyword}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from torch.nn import CrossEntropyLoss\n",
    "import torch\n",
    "\n",
    "\n",
    "def keytoken_weighted_loss(inputs, logits, keytoken_ids, alpha=1.0):\n",
    "    \"\"\"Causal language-modeling loss in which samples containing key tokens count more.\n",
    "\n",
    "    Args:\n",
    "        inputs: token ids, indexed as (batch, sequence) by the code below.\n",
    "        logits: model outputs, indexed as (batch, sequence, vocabulary).\n",
    "        keytoken_ids: ids of the tokens whose occurrences raise a sample's weight.\n",
    "            May be empty, in which case every sample gets the weight `alpha`.\n",
    "        alpha: global scale applied to every sample weight.\n",
    "\n",
    "    Returns:\n",
    "        Scalar tensor: the batch mean of each sample's average token loss multiplied\n",
    "        by `alpha * (1 + number of key-token occurrences in the sample)`.\n",
    "    \"\"\"\n",
    "    # Shift so that tokens < n predict n\n",
    "    shift_labels = inputs[..., 1:].contiguous()\n",
    "    shift_logits = logits[..., :-1, :].contiguous()\n",
    "    # Calculate per-token loss; `reduction=\"none\"` replaces the deprecated `reduce=False`\n",
    "    loss_fct = CrossEntropyLoss(reduction=\"none\")\n",
    "    loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))\n",
    "    # Resize and average loss per sample\n",
    "    loss_per_sample = loss.view(shift_logits.size(0), shift_logits.size(1)).mean(dim=1)\n",
    "    # Count key-token occurrences per sample (`torch.stack` rejects an empty list)\n",
    "    if len(keytoken_ids) > 0:\n",
    "        weights = torch.stack([(inputs == kt).float() for kt in keytoken_ids]).sum(\n",
    "            dim=[0, 2]\n",
    "        )\n",
    "    else:\n",
    "        weights = torch.zeros(inputs.size(0), device=inputs.device)\n",
    "    weights = alpha * (1.0 + weights)\n",
    "    # Calculate weighted average\n",
    "    weighted_loss = (loss_per_sample * weights).mean()\n",
    "    return weighted_loss"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from torch.utils.data.dataloader import DataLoader\n",
    "\n",
    "tokenized_dataset.set_format(\"torch\")\n",
    "train_dataloader = DataLoader(tokenized_dataset[\"train\"], batch_size=32, shuffle=True)\n",
    "eval_dataloader = DataLoader(tokenized_dataset[\"valid\"], batch_size=32)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "weight_decay = 0.1\n",
    "\n",
    "\n",
    "def get_grouped_params(\n",
    "    model,\n",
    "    no_decay=(\"bias\", \"LayerNorm.weight\", \"ln_1.weight\", \"ln_2.weight\", \"ln_f.weight\"),\n",
    "):\n",
    "    \"\"\"Split the parameters of `model` into optimizer groups with and without weight decay.\n",
    "\n",
    "    A parameter lands in the no-decay group when its name contains any of the\n",
    "    substrings in `no_decay`. GPT-2 names its layer norms `ln_1`, `ln_2` and `ln_f`,\n",
    "    which the generic \"LayerNorm.weight\" pattern does not match, so they are\n",
    "    listed explicitly in the default.\n",
    "\n",
    "    Returns:\n",
    "        A list of two dicts usable as optimizer parameter groups: the first uses\n",
    "        the module-level `weight_decay`, the second uses 0.0.\n",
    "    \"\"\"\n",
    "    params_with_wd, params_without_wd = [], []\n",
    "    for n, p in model.named_parameters():\n",
    "        if any(nd in n for nd in no_decay):\n",
    "            params_without_wd.append(p)\n",
    "        else:\n",
    "            params_with_wd.append(p)\n",
    "    return [\n",
    "        {\"params\": params_with_wd, \"weight_decay\": weight_decay},\n",
    "        {\"params\": params_without_wd, \"weight_decay\": 0.0},\n",
    "    ]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def evaluate():\n",
    "    \"\"\"Evaluate the global `model` on the global `eval_dataloader`.\n",
    "\n",
    "    Leaves the model in eval mode; callers that keep training must call\n",
    "    `model.train()` afterwards.\n",
    "\n",
    "    Returns:\n",
    "        A `(loss, perplexity)` tuple of Python floats, where `loss` is the mean of\n",
    "        the gathered per-batch losses and `perplexity` is `exp(loss)`.\n",
    "    \"\"\"\n",
    "    model.eval()\n",
    "    losses = []\n",
    "    for batch in eval_dataloader:\n",
    "        with torch.no_grad():\n",
    "            outputs = model(batch[\"input_ids\"], labels=batch[\"input_ids\"])\n",
    "        # Flatten the loss to 1-D before gathering: `torch.cat` below cannot\n",
    "        # concatenate zero-dimensional tensors.\n",
    "        losses.append(accelerator.gather(outputs.loss.reshape(-1)))\n",
    "    loss = torch.mean(torch.cat(losses))\n",
    "    # torch.exp saturates to inf instead of raising, so no overflow handling is needed\n",
    "    perplexity = torch.exp(loss)\n",
    "    return loss.item(), perplexity.item()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "model = GPT2LMHeadModel(config)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from torch.optim import AdamW\n",
    "\n",
    "optimizer = AdamW(get_grouped_params(model), lr=5e-4)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from accelerate import Accelerator\n",
    "\n",
    "# `mixed_precision=\"fp16\"` replaces the legacy `fp16=True` flag, which recent\n",
    "# Accelerate releases no longer accept.\n",
    "accelerator = Accelerator(mixed_precision=\"fp16\")\n",
    "\n",
    "model, optimizer, train_dataloader, eval_dataloader = accelerator.prepare(\n",
    "    model, optimizer, train_dataloader, eval_dataloader\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from transformers import get_scheduler\n",
    "\n",
    "num_train_epochs = 1\n",
    "# Must match the value used by the training loop: the optimizer and the scheduler\n",
    "# only step once every `gradient_accumulation_steps` batches.\n",
    "gradient_accumulation_steps = 8\n",
    "# Number of batches per epoch, and in total (the latter sizes the progress bar)\n",
    "num_update_steps_per_epoch = len(train_dataloader)\n",
    "num_training_steps = num_train_epochs * num_update_steps_per_epoch\n",
    "\n",
    "lr_scheduler = get_scheduler(\n",
    "    name=\"linear\",\n",
    "    optimizer=optimizer,\n",
    "    num_warmup_steps=1_000,\n",
    "    # The schedule has to span the optimizer updates, not the batches; otherwise\n",
    "    # the learning rate never finishes decaying.\n",
    "    num_training_steps=num_train_epochs\n",
    "    * (num_update_steps_per_epoch // gradient_accumulation_steps),\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'sgugger/codeparrot-ds-accelerate'"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from huggingface_hub import Repository, get_full_repo_name\n",
    "\n",
    "model_name = \"codeparrot-ds-accelerate\"\n",
    "repo_name = get_full_repo_name(model_name)\n",
    "repo_name"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "output_dir = \"codeparrot-ds-accelerate\"\n",
    "repo = Repository(output_dir, clone_from=repo_name)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(10.934126853942871, 56057.14453125)"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "evaluate()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from tqdm.notebook import tqdm\n",
    "\n",
    "gradient_accumulation_steps = 8\n",
    "eval_steps = 5_000\n",
    "\n",
    "model.train()\n",
    "completed_steps = 0\n",
    "for epoch in range(num_train_epochs):\n",
    "    for step, batch in tqdm(\n",
    "        enumerate(train_dataloader, start=1), total=num_training_steps\n",
    "    ):\n",
    "        logits = model(batch[\"input_ids\"]).logits\n",
    "        loss = keytoken_weighted_loss(batch[\"input_ids\"], logits, keytoken_ids)\n",
    "        if step % 100 == 0:\n",
    "            # Samples seen per step across all processes, derived from the current batch\n",
    "            samples_per_step = accelerator.num_processes * batch[\"input_ids\"].shape[0]\n",
    "            accelerator.print(\n",
    "                {\n",
    "                    \"lr\": optimizer.param_groups[0][\"lr\"],\n",
    "                    \"samples\": step * samples_per_step,\n",
    "                    \"steps\": completed_steps,\n",
    "                    # `loss` has not been divided for accumulation yet, so log it as is\n",
    "                    \"loss/train\": loss.item(),\n",
    "                }\n",
    "            )\n",
    "        # Scale so the gradients accumulated over several batches average out\n",
    "        loss = loss / gradient_accumulation_steps\n",
    "        accelerator.backward(loss)\n",
    "        if step % gradient_accumulation_steps == 0:\n",
    "            accelerator.clip_grad_norm_(model.parameters(), 1.0)\n",
    "            optimizer.step()\n",
    "            lr_scheduler.step()\n",
    "            optimizer.zero_grad()\n",
    "            completed_steps += 1\n",
    "        # Evaluate and checkpoint every `eval_steps` optimizer updates\n",
    "        if (step % (eval_steps * gradient_accumulation_steps)) == 0:\n",
    "            eval_loss, perplexity = evaluate()\n",
    "            accelerator.print({\"loss/eval\": eval_loss, \"perplexity\": perplexity})\n",
    "            model.train()\n",
    "            accelerator.wait_for_everyone()\n",
    "            unwrapped_model = accelerator.unwrap_model(model)\n",
    "            unwrapped_model.save_pretrained(output_dir, save_function=accelerator.save)\n",
    "            if accelerator.is_main_process:\n",
    "                tokenizer.save_pretrained(output_dir)\n",
    "                repo.push_to_hub(\n",
    "                    commit_message=f\"Training in progress step {step}\", blocking=False\n",
    "                )"
   ]
  }
 ],
 "metadata": {
  "colab": {
   "name": "Training a causal language model from scratch (PyTorch)",
   "provenance": []
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
